
import time
import sys
import numpy as np
import os
from torch import nn

from utils import *
from simSweep import *
from Tx import *
import device



if __name__ == "__main__":
	# Script entry point: build the config, generate TX data, run every
	# equalizer enabled in cfg['eqs'], report runtime, and optionally plot
	# BER vs. SNR for the equalizers whose BER lists are collected below.
	import importlib
	# `torch` is used explicitly (torch.load); import it here rather than
	# relying on one of the star imports to leak the name.
	import torch

	#*************************HEADER***********************#
	startTime = time.time()
	np.random.seed(1)
	args = parsing_def()
	# Config modules are looked up in ./config (relative to the current working
	# directory) as config_<name>; each must expose a `config` dict.
	sys.path.insert(0, './config')
	config_module = importlib.import_module('config_{}'.format(args.config))
	cfg = config_module.config
	eqCfg = cfg['eqs']
	#******************************************************#

	tx = Tx(mod=eqCfg['mod'])
	chInGeneral = tx.run(int(eqCfg['dataSize']))

	#********************************#
	# Switching which EQ is ON
	#********************************#
	# stateGen is requested only when the Viterbi, forward-backward or forward
	# equalizer is enabled; every other simSweep argument is identical.
	needStateGen = bool(
		eqCfg['viterbiOn'] or eqCfg['fwdBwdOn'] or eqCfg['fwdOn']
	)
	sweep = simSweep(
		chSbr=eqCfg['chSBR'],
		eqSbr=eqCfg['eqSBR'],
		snrList=eqCfg['snrList'],
		originData=chInGeneral,
		mod=eqCfg['mod'],
		flagN=eqCfg['noiseFlag'],
		stateGen=needStateGen,
	)

	# BER curves to plot, keyed by legend label. Insertion order is the plot
	# order. Only equalizers that actually ran are present, so the plotting
	# code never touches an unbound result.
	berCurves = {}

	if eqCfg['firOn']:
		berCurves['FFE'] = sweep.fir(ffeTapNum=eqCfg['firTapNumForFir'])
		print("", flush=True)
	if eqCfg['dfeOn']:
		berCurves['DFE'] = sweep.dfe(dfeTapNum=eqCfg['dfeTapNumForDfe'])
		print("", flush=True)
	if eqCfg['firDfeOn']:
		ffeMaxTapNum = sweep.firDfeSearchMaxTap(
			ffeTapNum=eqCfg['firTapNumForFirDfe'],
			dfeTapNum=eqCfg['dfeTapNumForFirDfe'],
		)
		berCurves['FFE+DFE'] = sweep.firDfe(
			ffeTapNum=eqCfg['firTapNumForFirDfe'],
			ffeMaxTapNum=ffeMaxTapNum,
			dfeTapNum=eqCfg['dfeTapNumForFirDfe'],
		)
		print("", flush=True)
	if eqCfg['viterbiOn']:
		# Result is not collected or plotted (same as before).
		sweep.viterbiOverlap(blockSizeList=eqCfg['blockSizeList'])
		print("", flush=True)
	if eqCfg['fwdBwdOn']:
		fwdBwdBerList, _fwdBwdProb = sweep.fwdBwd(
			fwdBwdLen=eqCfg['fwdBwdLen'],
			snrOvrd=eqCfg['fwdBwdSnrOvrd'],
		)
		berCurves['FwdBwd'] = fwdBwdBerList
		print("", flush=True)
	if eqCfg['fwdOn']:
		fwdBerList, _fwdProb = sweep.fwd(fwdLen=eqCfg['fwdLen'])
		berCurves['Fwd'] = fwdBerList
		print("", flush=True)
	if eqCfg['neqOn']:
		# NOTE(review): the delay comes from cfg['eval']['inSize'] while the
		# network below is fed cfg['eqs']['inSize']; confirm these are meant
		# to be two different config entries.
		delay = int((cfg['eval']['inSize']) / 4)
		# Offset by the index of the largest channel-SBR sample (presumably
		# the main cursor).
		chSbr = list(eqCfg['chSBR'])
		delayOffset = -chSbr.index(max(chSbr))
		modelFiles = eqCfg['modelFileList']
		multiModel = len(modelFiles) > 1
		for modelFile in modelFiles:
			if not os.path.exists(modelFile):
				print("")
				print("Not exists %s" % modelFile)
				print("")
				continue
			# torch.load unpickles the file, so only load trusted model files.
			# map_location lets a checkpoint saved on another device (e.g. GPU)
			# load straight onto the target device.
			# NOTE(review): torch >= 2.6 defaults to weights_only=True, which
			# rejects fully pickled nn.Module objects; pass weights_only=False
			# if these files are whole-model pickles.
			nEQLoad = torch.load(modelFile, map_location=device.device)
			nEQLoad = nEQLoad.to(device.device)
			nnFwdBwdBerList = sweep.nnFwdBwd(
				neuralEQ=nEQLoad,
				lossFn=nn.MSELoss(),
				batchSize=8192,
				inSize=eqCfg['inSize'],
				outSize=eqCfg['outSize'],
				delay=delay + delayOffset,
			)
			# Keep every model's curve. With several models, the label is
			# disambiguated by file name instead of only the last one surviving.
			label = 'nEQ ({})'.format(os.path.basename(modelFile)) if multiModel else 'nEQ'
			berCurves[label] = nnFwdBwdBerList
			print("", flush=True)

	timeSim = (time.time() - startTime) / 60.  # Unit: minutes
	print(f"Total simulation time: {timeSim} mins")

	if eqCfg['plot']:
		plt.figure()
		plt.xlabel('SNR (dB)', fontsize=13)
		plt.xticks(fontsize=13)
		plt.ylabel('BER ', fontsize=13)
		plt.yscale('log')
		plt.grid(True)

		for label, berList in berCurves.items():
			plt.plot(eqCfg['snrList'], berList, '-o', label=label)

		# Avoid the "no artists with labels" warning when nothing was plotted.
		if berCurves:
			plt.legend(loc='best')
		plt.show()

	
